Notes (old posts, page 2)

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 7)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
from tqdm import tnrange
Using TensorFlow backend.
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([  .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4729044328 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4727278616 dense_1: Denseinput:output:(None, 2)(None, 10)4729044328->4727278616 4729043432 dense_2: Denseinput:output:(None, 10)(None, 20)4727278616->4729043432 4729057176 logit: Denseinput:output:(None, 20)(None, 1)4729043432->4729057176 4721891144 activation_1: Activationinput:output:(None, 1)(None, 1)4729057176->4721891144
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.53143018484115601, 1.0]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4732353784 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4732435816 dense_3: Denseinput:output:(None, 3)(None, 10)4732353784->4732435816 4732434640 dense_4: Denseinput:output:(None, 10)(None, 20)4732435816->4732434640 4716115280 dense_5: Denseinput:output:(None, 20)(None, 2)4732434640->4716115280
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(200, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(200, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
Out[37]:
(-5, 5)
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=200, steps_per_epoch=15):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=60, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [44]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [45]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[45]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [46]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[46]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [47]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [48]:
sess = K.get_session()
In [49]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[49]:
True
In [50]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [51]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [52]:
elbo = make_elbo(ratio_estimator)
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [54]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [55]:
inference.compile(loss=inference_loss, 
                  optimizer=Adam(lr=LEARNING_RATE))
In [56]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [57]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[57]:
<tf.Tensor 'concat:0' shape=(200, 3) dtype=float32>
In [58]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[58]:
-3.8279114
In [59]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/200 [===>..........................] - ETA: 0s
Out[59]:
3.8279113578796387

Training

In [60]:
for epoch in tnrange(200, desc='epoch'):

    set_trainable(ratio_estimator, False)

    for _ in tnrange(1, desc='generator'):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, np.tile(y, reps=(BATCH_SIZE, 1)))

    set_trainable(discriminator, True)

    for _ in tnrange(3*50, desc='discriminator'):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)

In [61]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [62]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [63]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [64]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.gray)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 6)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4766144608 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4766145224 dense_1: Denseinput:output:(None, 2)(None, 10)4766144608->4766145224 4767336208 dense_2: Denseinput:output:(None, 10)(None, 20)4766145224->4767336208 4766143992 logit: Denseinput:output:(None, 20)(None, 1)4767336208->4766143992 4766535464 activation_1: Activationinput:output:(None, 1)(None, 1)4766143992->4766535464
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.72363483905792236, 0.40000000596046448]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4769765752 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4769787520 dense_3: Denseinput:output:(None, 3)(None, 10)4769765752->4769787520 4769190184 dense_4: Denseinput:output:(None, 10)(None, 20)4769787520->4769190184 4769949288 dense_5: Denseinput:output:(None, 20)(None, 2)4769190184->4769949288
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(128, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(128, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [44]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [45]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[45]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [46]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[46]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [47]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [48]:
sess = K.get_session()
In [49]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[49]:
True
In [50]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [51]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [52]:
elbo = make_elbo(ratio_estimator)
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [59]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [60]:
inference.compile(loss=inference_loss, 
                  optimizer='adam')
In [61]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [62]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[62]:
<tf.Tensor 'concat_1:0' shape=(128, 3) dtype=float32>
In [63]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[63]:
-3.9920437
In [64]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/128 [======>.......................] - ETA: 0s
Out[64]:
3.9920437335968018

Training

In [70]:
for epoch in range(3*200):

    set_trainable(ratio_estimator, False)

    for _ in range(1):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, np.tile(y, reps=(BATCH_SIZE, 1)))

    set_trainable(discriminator, True)

    for _ in range(3*50):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)
In [71]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [72]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [73]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [74]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 5)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4796441544 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4797073168 dense_1: Denseinput:output:(None, 2)(None, 10)4796441544->4797073168 4796631920 dense_2: Denseinput:output:(None, 10)(None, 20)4797073168->4796631920 4796443560 logit: Denseinput:output:(None, 20)(None, 1)4796631920->4796443560 4795146192 activation_1: Activationinput:output:(None, 1)(None, 1)4796443560->4795146192
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.63304531574249268, 0.60000002384185791]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4783101656 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4797501680 dense_3: Denseinput:output:(None, 3)(None, 10)4783101656->4797501680 4797502688 dense_4: Denseinput:output:(None, 10)(None, 20)4797501680->4797502688 4799943456 dense_5: Denseinput:output:(None, 20)(None, 2)4797502688->4799943456
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(128, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(128, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Inference Model Training

In [44]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[44]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [45]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[45]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [46]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [47]:
sess = K.get_session()
In [48]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[48]:
True
In [49]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [50]:
def make_elbo(ratio_estimator):

    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [51]:
elbo = make_elbo(ratio_estimator)
In [52]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [94]:
inference.compile(loss=make_elbo(ratio_estimator), 
                  optimizer='adam')
In [95]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [96]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/128 [======>.......................] - ETA: 0s
Out[96]:
-3.2424654364585876
In [105]:
# equiv. to use of tile above
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[105]:
<tf.Tensor 'concat_10:0' shape=(128, 3) dtype=float32>
In [107]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[107]:
-3.2424655